{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JfOIB1KdkbYW"
      },
      "source": [
        "##### Copyright 2024 The AI Edge Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "Ojb0aXCmBgo7"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M9Y4JZ0ZGoE4"
      },
      "source": [
        "# Super resolution with TensorFlow Lite"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "q3FoFSLBjIYK"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lite/examples/super_resolution/overview\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/super_resolution/overview.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/super_resolution/overview.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/examples/super_resolution/overview.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://tfhub.dev/captain-pool/esrgan-tf2/1\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" /\u003eSee TF Hub model\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-uF3N4BbaMvA"
      },
      "source": [
        "## Overview"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "isbXET4vVHfu"
      },
      "source": [
        "Single Image Super Resolution (SISR) is the task of recovering a high resolution (HR) image from its low resolution (LR) counterpart.\n",
        "\n",
        "This notebook uses ESRGAN\n",
        "([ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks](https://arxiv.org/abs/1809.00219)) and runs inference on the pretrained model with TensorFlow Lite.\n",
        "\n",
        "The TFLite model is converted from this\n",
        "[implementation](https://tfhub.dev/captain-pool/esrgan-tf2/1) hosted on TF Hub. The converted model upsamples a 50x50 low resolution image to a 200x200 high resolution image (scale factor = 4). For a different input size or scale factor, you need to re-convert or re-train the original model."
      ]
    },
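    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because the converted model accepts only a 50x50 input, a larger image has to be cropped (or resized) to that size before inference. The snippet below is a minimal sketch of one option, a center crop. `center_crop_box` is a hypothetical helper written for this illustration; it is not part of TensorFlow or of the original model code.\n",
        "\n",
        "```python\n",
        "def center_crop_box(height, width, size=50):\n",
        "    # Return the (top, left) offsets of a centered size x size crop.\n",
        "    # Hypothetical helper; the default size matches the 50x50 model input.\n",
        "    if height < size or width < size:\n",
        "        raise ValueError(f'image {height}x{width} is smaller than {size}x{size}')\n",
        "    top = (height - size) // 2\n",
        "    left = (width - size) // 2\n",
        "    return top, left\n",
        "\n",
        "\n",
        "# Example: offsets for a 120x80 (height x width) image.\n",
        "top, left = center_crop_box(120, 80)\n",
        "print(top, left)  # 35 15\n",
        "```\n",
        "\n",
        "The offsets could then be passed to `tf.image.crop_to_bounding_box(image, top, left, 50, 50)` before the image is batched and cast to `float32`."
      ]
    },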
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2dQlTqiffuoU"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qKyMtsGqu3zH"
      },
      "source": [
        "Let's install required libraries first."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7YTT1Rxsw3A9"
      },
      "outputs": [],
      "source": [
        "!pip install matplotlib tensorflow tensorflow-hub"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Clz5Kl97FswD"
      },
      "source": [
        "Import dependencies."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2xh1kvGEBjuP"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow_hub as hub\n",
        "import matplotlib.pyplot as plt\n",
        "print(tf.__version__)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i5miVfL4kxTA"
      },
      "source": [
        "Download the ESRGAN model from TF Hub and convert it to TensorFlow Lite."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X5PvXIXRwvHj"
      },
      "outputs": [],
      "source": [
        "# Load the pretrained ESRGAN model from TF Hub.\n",
        "model = hub.load(\"https://tfhub.dev/captain-pool/esrgan-tf2/1\")\n",
        "concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]\n",
        "\n",
        "# Wrap the serving signature so the converted model has a fixed 1x50x50x3 input.\n",
        "@tf.function(input_signature=[tf.TensorSpec(shape=[1, 50, 50, 3], dtype=tf.float32)])\n",
        "def esrgan_fn(lr_image):\n",
        "  return concrete_func(lr_image)\n",
        "\n",
        "converter = tf.lite.TFLiteConverter.from_concrete_functions([esrgan_fn.get_concrete_function()], model)\n",
        "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
        "tflite_model = converter.convert()\n",
        "\n",
        "# Save the TF Lite model.\n",
        "with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as model_file:\n",
        "  model_file.write(tflite_model)\n",
        "\n",
        "esrgan_model_path = './ESRGAN.tflite'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jH5-xPkyUEqt"
      },
      "source": [
        "Download a test image (insect head)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "suWiStTWgK6e"
      },
      "outputs": [],
      "source": [
        "test_img_path = tf.keras.utils.get_file('lr.jpg', 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rgQ4qRuFNpyW"
      },
      "source": [
        "## Generate a super resolution image using TensorFlow Lite"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "J9FV4btf02-2"
      },
      "outputs": [],
      "source": [
        "# Read the test image and convert it to a float32 tensor with a leading batch dimension.\n",
        "lr = tf.io.read_file(test_img_path)\n",
        "lr = tf.image.decode_jpeg(lr)\n",
        "lr = tf.expand_dims(lr, axis=0)\n",
        "lr = tf.cast(lr, tf.float32)\n",
        "\n",
        "# Load TFLite model and allocate tensors.\n",
        "interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)\n",
        "interpreter.allocate_tensors()\n",
        "\n",
        "# Get input and output tensors.\n",
        "input_details = interpreter.get_input_details()\n",
        "output_details = interpreter.get_output_details()\n",
        "\n",
        "# Run the model. set_tensor expects a NumPy array.\n",
        "interpreter.set_tensor(input_details[0]['index'], lr.numpy())\n",
        "interpreter.invoke()\n",
        "\n",
        "# Extract the output and postprocess it\n",
        "output_data = interpreter.get_tensor(output_details[0]['index'])\n",
        "sr = tf.squeeze(output_data, axis=0)\n",
        "sr = tf.clip_by_value(sr, 0, 255)\n",
        "sr = tf.round(sr)\n",
        "sr = tf.cast(sr, tf.uint8)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EwddQrDUNQGO"
      },
      "source": [
        "## Visualize the result"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aasKuozt1gNd"
      },
      "outputs": [],
      "source": [
        "# Show the low resolution input.\n",
        "lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)\n",
        "plt.figure(figsize=(1, 1))\n",
        "plt.title('LR')\n",
        "plt.imshow(lr.numpy())\n",
        "\n",
        "# Compare the ESRGAN output with bicubic upsampling side by side.\n",
        "plt.figure(figsize=(10, 4))\n",
        "plt.subplot(1, 2, 1)\n",
        "plt.title('ESRGAN (x4)')\n",
        "plt.imshow(sr.numpy())\n",
        "\n",
        "bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)\n",
        "# Bicubic interpolation can overshoot [0, 255], so round and clip before casting.\n",
        "bicubic = tf.cast(tf.clip_by_value(tf.round(bicubic), 0, 255), tf.uint8)\n",
        "plt.subplot(1, 2, 2)\n",
        "plt.title('Bicubic')\n",
        "plt.imshow(bicubic.numpy());"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0kb-fkogObjq"
      },
      "source": [
        "## Performance Benchmarks"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tNzdgpqTy5P3"
      },
      "source": [
        "Performance benchmark numbers are generated with the tool\n",
        "[described here](https://www.tensorflow.org/lite/performance/benchmarks).\n",
        "\n",
        "\u003ctable\u003e\n",
        "  \u003cthead\u003e\n",
        "    \u003ctr\u003e\n",
        "      \u003cth\u003eModel Name\u003c/th\u003e\n",
        "      \u003cth\u003eModel Size\u003c/th\u003e\n",
        "      \u003cth\u003eDevice\u003c/th\u003e\n",
        "      \u003cth\u003eCPU\u003c/th\u003e\n",
        "      \u003cth\u003eGPU\u003c/th\u003e\n",
        "    \u003c/tr\u003e\n",
        "  \u003c/thead\u003e\n",
        "  \u003ctr\u003e\n",
        "    \u003ctd rowspan=\"2\"\u003e\n",
        "      super resolution (ESRGAN)\n",
        "    \u003c/td\u003e\n",
        "    \u003ctd rowspan=\"2\"\u003e\n",
        "      4.8 MB\n",
        "    \u003c/td\u003e\n",
        "    \u003ctd\u003ePixel 3\u003c/td\u003e\n",
        "    \u003ctd\u003e586.8 ms*\u003c/td\u003e\n",
        "    \u003ctd\u003e128.6 ms\u003c/td\u003e\n",
        "  \u003c/tr\u003e\n",
        "  \u003ctr\u003e\n",
        "    \u003ctd\u003ePixel 4\u003c/td\u003e\n",
        "    \u003ctd\u003e385.1 ms*\u003c/td\u003e\n",
        "    \u003ctd\u003e130.3 ms\u003c/td\u003e\n",
        "  \u003c/tr\u003e\n",
        "\u003c/table\u003e\n",
        "\n",
        "\\* 4 threads used."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "super_resolution.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
